{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f457372b",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/finetuning/knowledge/finetune_retrieval_aug.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2aff20d1-c522-4e9f-9ed8-3d70c4788601",
   "metadata": {},
   "source": [
    "# Fine-tuning with Retrieval Augmentation\n",
    "\n",
    "Here we try fine-tuning an LLM with retrieval augmentation, as referenced from the RA-DIT paper: https://arxiv.org/abs/2310.01352.\n",
    "\n",
    "For a given (query, response) input/output example, we retrieve the k text chunks with a retriever (the quality of the retriever doesn't have to be perfect, and in fact can be primitive). We then format each query with individually retrieved context, to create k examples (query + context_i, response) for fine-tuning.\n",
    "\n",
    "The idea is to allow the LLM to better use background knowledge to synthesize a correct answer, or to synthesize a correct answer even in the absence of good background knowledge. This will enable the LLM to reason from its priors a bit better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71c4993c-d527-4840-8182-4bfe03bbdb69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "from llama_index import ServiceContext\n",
    "from llama_index.llms import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3492e87-6453-4ad1-b066-57f9d4b5fdb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea6541f7-eeab-4e43-bb53-69bd7ffddec8",
   "metadata": {},
   "source": [
    "## Setup + Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a936769a-4ec7-4c1e-b782-9973fd0c0236",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mkdir: data: File exists\n"
     ]
    }
   ],
   "source": [
    "!mkdir data && wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8432f89c-8fbc-4afc-98c3-ac49d6adb176",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from llama_hub.file.pdf.base import PDFReader\n",
    "from llama_hub.file.unstructured.base import UnstructuredReader\n",
    "from llama_hub.file.pymu_pdf.base import PyMuPDFReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12efa907-2cf3-46f2-b66f-79f93737c761",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PyMuPDFReader()\n",
    "docs0 = loader.load(file_path=Path(\"./data/llama2.pdf\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "914ef39a-8f5b-4cc9-af62-0bef2c6689d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import Document\n",
    "\n",
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "metadata = {\n",
    "    \"paper_title\": \"Llama 2: Open Foundation and Fine-Tuned Chat Models\"\n",
    "}\n",
    "docs = [Document(text=doc_text, metadata=metadata)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58d458d-71b0-4c1d-bddc-bc2ada565c7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(docs[0].get_content())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1663f6d2-ba1b-4184-89c8-1e73cfc6c0c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.callbacks import CallbackManager\n",
    "\n",
    "callback_manager = CallbackManager([])\n",
    "\n",
    "gpt_35_context = ServiceContext.from_defaults(\n",
    "    llm=OpenAI(model=\"gpt-3.5-turbo-0613\", temperature=0.3),\n",
    "    callback_manager=callback_manager,\n",
    ")\n",
    "gpt_4_context = ServiceContext.from_defaults(\n",
    "    llm=OpenAI(model=\"gpt-4-0613\", temperature=0.3),\n",
    "    callback_manager=callback_manager,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21c71cb2-4ae3-4bd2-b67e-77c99ae11110",
   "metadata": {},
   "source": [
    "### Get Nodes, Setup Vector Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65012a61-1faf-4ad1-a7b5-02d81af9b3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.node_parser import SentenceSplitter\n",
    "from llama_index import VectorStoreIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d1b721-ce41-41af-aed9-630d1be5a24c",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_parser = SentenceSplitter()\n",
    "nodes = node_parser.get_nodes_from_documents(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8bfe8e7-3a8e-48fd-b80b-3df752627dec",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_index = VectorStoreIndex(nodes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ffa721-ea43-4a82-bfe3-ad7a159f3b6e",
   "metadata": {},
   "source": [
    "## Generate Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d50c8e-6059-4b52-83a0-733e494c8e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation import (\n",
    "    DatasetGenerator,\n",
    "    QueryResponseDataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15e97ab8-b919-40ae-988a-f4c38d5c4dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_context = ServiceContext.from_defaults(\n",
    "    llm=OpenAI(model=\"gpt-4\", temperature=0), callback_manager=callback_manager\n",
    ")\n",
    "dataset_generator = DatasetGenerator(\n",
    "    nodes[:39],\n",
    "    service_context=eval_context,\n",
    "    show_progress=True,\n",
    "    num_questions_per_chunk=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f50adaf-6312-4bcd-bf3d-f088e537b21e",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset = await dataset_generator.agenerate_dataset_from_nodes(num=60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cf77278-9e13-4179-a3e4-c83ecce4c8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset.save_json(\"data_rag/qa_pairs.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb070aa7-1e95-4f74-aeff-3af8fc205c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional\n",
    "eval_dataset = QueryResponseDataset.from_json(\"data_rag/qa_pairs.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d1acb2c-14f0-4909-a56e-19c6d3b137a3",
   "metadata": {},
   "source": [
    "#### Option 2: Load from existing data \n",
    "\n",
    "If you were already using the fine-tuning knowledge notebook, you can use that instead. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dda45602-3938-4e45-ad19-487da00af591",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# load data in from .jsonl format\n",
    "def load_dataset_from_other_nb(path):\n",
    "    fp = open(path, \"r\")\n",
    "    qr_pairs = []\n",
    "    for line in fp:\n",
    "        qa_pair = json.loads(line)\n",
    "        query_str = qa_pair[\"query\"]\n",
    "        response_str = qa_pair[\"response\"]\n",
    "        qr_pairs.append((query_str, response_str))\n",
    "\n",
    "    return qr_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14b88c84-dd5b-4670-aa68-dfb3b5a1d516",
   "metadata": {},
   "outputs": [],
   "source": [
    "qr_pairs = load_dataset_from_other_nb(\"data/qa_pairs_2.jsonl\")\n",
    "eval_dataset = QueryResponseDataset.from_qr_pairs(qr_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c34c07-52b2-4727-9dae-78e6fad76165",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eba36e6-81c9-4bf8-a6e8-520ef97b72c2",
   "metadata": {},
   "source": [
    "### For each Datapoint, Fetch Retrieved Context with a Retriever\n",
    "\n",
    "For each (question, response) pair, fetch the top-k context with a retriever.\n",
    "\n",
    "For each pair, we create k (question + context_i, response) new pairs, where we format each input with the QA prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e99114-2ca8-4c73-9d91-862f8f88090f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import VectorStoreIndex\n",
    "from llama_index.prompts import PromptTemplate\n",
    "\n",
    "qa_prompt_tmpl_str = (\n",
    "    \"Context information is below.\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"Given the context information and not prior knowledge, \"\n",
    "    \"answer the query.\\n\"\n",
    "    \"Query: {query_str}\\n\"\n",
    "    \"Answer: \"\n",
    ")\n",
    "qa_prompt_tmpl = PromptTemplate(qa_prompt_tmpl_str)\n",
    "\n",
    "vector_retriever = vector_index.as_retriever(similarity_top_k=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d544bd5a-ae7b-45d8-9dbe-fc02b64b62f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "\n",
    "def augment_data_with_retrieval(dataset, retriever, separate_context=False):\n",
    "    data_list = dataset.qr_pairs\n",
    "    new_data_list = []\n",
    "    for query_str, response in tqdm(data_list):\n",
    "        retrieved_nodes = retriever.retrieve(query_str)\n",
    "        retrieved_txts = [n.get_content() for n in retrieved_nodes]\n",
    "        if separate_context:\n",
    "            for retrieved_txt in retrieved_txts:\n",
    "                fmt_query_str = qa_prompt_tmpl.format(\n",
    "                    query_str=query_str, context_str=retrieved_txt\n",
    "                )\n",
    "                new_data_list.append((fmt_query_str, response))\n",
    "        else:\n",
    "            context_str = \"\\n\\n\".join(retrieved_txts)\n",
    "            fmt_query_str = qa_prompt_tmpl.format(\n",
    "                query_str=query_str, context_str=context_str\n",
    "            )\n",
    "            new_data_list.append((fmt_query_str, response))\n",
    "    return new_data_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "128dc6b3-b6b0-44d9-9e4b-860498b2f454",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_qr_pairs = augment_data_with_retrieval(\n",
    "    eval_dataset, vector_retriever, separate_context=False\n",
    ")\n",
    "new_eval_dataset = QueryResponseDataset.from_qr_pairs(new_qr_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1dfe68-8475-47ef-95ef-9caefa23a1b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_eval_dataset.save_json(\"data_rag/qa_pairs_ra.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce1829f4-3e94-482c-a4be-0447c5239671",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_eval_dataset = QueryResponseDataset.from_json(\"data_rag/qa_pairs_ra.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83f40ad6-2824-4242-a50c-e595f4b4e368",
   "metadata": {},
   "source": [
    "### Split into Training and Validation Sets\n",
    "\n",
    "We split into training and validation sets.\n",
    "\n",
    "**NOTE**: We shuffle the data before splitting. This helps ensure that the training data has coverage throughout the document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "813574f3-db3a-4ed6-89e1-fac1a6be9194",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "import random\n",
    "\n",
    "\n",
    "def split_train_val(dataset, train_split=0.7):\n",
    "    lines = dataset.qr_pairs\n",
    "\n",
    "    # shuffle the lines to make sure that the \"train questions\" cover most fo the context\n",
    "    shuffled_lines = deepcopy(lines)\n",
    "    random.shuffle(shuffled_lines)\n",
    "\n",
    "    split_idx = int(train_split * len(shuffled_lines))\n",
    "    train_lines = shuffled_lines[:split_idx]\n",
    "    val_lines = shuffled_lines[split_idx:]\n",
    "\n",
    "    return train_lines, val_lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9ca273-fc66-49bc-8a00-6823906485fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_lines, val_lines = split_train_val(new_eval_dataset, train_split=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feabd5ab-5902-4720-91fb-764878921dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = QueryResponseDataset.from_qr_pairs(train_lines)\n",
    "val_dataset = QueryResponseDataset.from_qr_pairs(val_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6cace48-70e9-401b-b29d-ab5c8ff173f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset.save_json(\"data_rag/qa_pairs_train.json\")\n",
    "val_dataset.save_json(\"data_rag/qa_pairs_val.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "094e9373-717f-4456-a4fd-f13582181bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = QueryResponseDataset.from_json(\"data_rag/qa_pairs_train.json\")\n",
    "val_dataset = QueryResponseDataset.from_json(\"data_rag/qa_pairs_val.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08f44ccf-c334-42ae-939c-c64cc4a164e3",
   "metadata": {},
   "source": [
    "### Format into Training Data\n",
    "\n",
    "Format into training data for OpenAI's finetuning endpoints.\n",
    "\n",
    "**NOTE**: We don't use our `OpenAIFinetuningHandler` because that logs the full input prompt including context as the user message. Here we just want to log the query as the user message, because we want to fine-tune gpt-3.5-turbo to \"bake in knowledge\" into the fine-tuned model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bebe2432-5867-41ef-ac08-cb5cf55d421c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_openai_data(dataset, out_path):\n",
    "    # out_fp = open(\"data_rag/qa_pairs_openai.jsonl\", \"w\")\n",
    "    out_fp = open(out_path, \"w\")\n",
    "    # TODO: try with different system prompts\n",
    "    system_prompt = {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": (\n",
    "            \"You are a helpful assistant helping to answer questions about the\"\n",
    "            \" Llama 2 paper.\"\n",
    "        ),\n",
    "    }\n",
    "    train_qr_pairs = dataset.qr_pairs\n",
    "    for line in train_qr_pairs:\n",
    "        query, response = line\n",
    "        user_prompt = {\"role\": \"user\", \"content\": query}\n",
    "        assistant_prompt = {\"role\": \"assistant\", \"content\": response}\n",
    "        out_dict = {\n",
    "            \"messages\": [system_prompt, user_prompt, assistant_prompt],\n",
    "        }\n",
    "        out_fp.write(json.dumps(out_dict) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7baa2821-71de-4887-b9ed-53f474d41118",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_openai_data(train_dataset, \"data_rag/qa_pairs_openai.jsonl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06875df5-9e53-4f2e-a300-0d16881adfde",
   "metadata": {},
   "source": [
    "## Fine-tune the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0268eeb4-1a72-47b0-8fa8-4dd707edec69",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import OpenAIFinetuneEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b04dc433-ca15-4076-a71c-e35df18ff0a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine = OpenAIFinetuneEngine(\n",
    "    \"gpt-3.5-turbo\",\n",
    "    \"data_rag/qa_pairs_openai.jsonl\",\n",
    "    # start_job_id=\"<start-job-id>\"  # if you have an existing job, can specify id here\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b36f698f-591c-44c9-9e16-9328102daef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b332bc43-304c-4883-ad33-b0c1610e5572",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<FineTuningJob fine_tuning.job id=ftjob-Rue4Yti7XpddPFYB6CnZadGo at 0x2cf346750> JSON: {\n",
       "  \"object\": \"fine_tuning.job\",\n",
       "  \"id\": \"ftjob-Rue4Yti7XpddPFYB6CnZadGo\",\n",
       "  \"model\": \"gpt-3.5-turbo-0613\",\n",
       "  \"created_at\": 1696407754,\n",
       "  \"finished_at\": 1696411978,\n",
       "  \"fine_tuned_model\": \"ft:gpt-3.5-turbo-0613:llamaindex::85sXTAx1\",\n",
       "  \"organization_id\": \"org-1ZDAvajC6v2ZtAP9hLEIsXRz\",\n",
       "  \"result_files\": [\n",
       "    \"file-9EY2Wj1Gb2lzcZi1PMqVnIpt\"\n",
       "  ],\n",
       "  \"status\": \"succeeded\",\n",
       "  \"validation_file\": null,\n",
       "  \"training_file\": \"file-0iLbjiXwv33i1eZQYNXjE4np\",\n",
       "  \"hyperparameters\": {\n",
       "    \"n_epochs\": 3\n",
       "  },\n",
       "  \"trained_tokens\": 1754577,\n",
       "  \"error\": null\n",
       "}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_engine.get_current_job()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667218db-a36c-42c8-aea5-33ad08700215",
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_model = finetune_engine.get_finetuned_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71325911-37f6-4be3-8a18-4ff2cda83cb4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OpenAI(callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x176cfca90>, model='ft:gpt-3.5-turbo-0613:llamaindex::85sXTAx1', temperature=0.1, max_tokens=None, additional_kwargs={}, max_retries=10, api_key='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', api_type='open_ai', api_base='https://api.openai.com/v1', api_version='', class_type='openai')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ft_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baf60cc8-6f19-459c-9b7d-7d17b641a8a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use fine-tuned model in RAG system\n",
    "from llama_index import ServiceContext\n",
    "\n",
    "ft_context = ServiceContext.from_defaults(\n",
    "    llm=ft_model,\n",
    "    callback_manager=callback_manager,\n",
    "    system_prompt=(\n",
    "        \"You are a helpful assistant helping to answer questions about the\"\n",
    "        \" Llama 2 paper.\"\n",
    "    ),\n",
    ")\n",
    "# fine-tuned RAG system\n",
    "ft_query_engine = vector_index.as_query_engine(\n",
    "    similarity_top_k=1, service_context=ft_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af3d6a5-f452-489b-8576-dd6aaa25cd6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The margin component is added in the loss of the reward model in Llama 2 by subtracting the reward score of the worse sample from the reward score of the better sample. This difference is then compared to a margin threshold. If the difference is greater than the margin threshold, it is considered a positive example and the loss is set to zero. If the difference is smaller than the margin threshold, it is considered a negative example and the loss is set to the margin threshold minus the difference. This margin component helps to separate the reward scores of the better and worse samples, making the reward model more accurate in distinguishing between them.\n"
     ]
    }
   ],
   "source": [
    "response = ft_query_engine.query(\n",
    "    \"How is the margin component added in the loss of the reward model in\"\n",
    "    \" Llama 2?\"\n",
    ")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "718fbe91-0b90-4a5d-a9cb-2878eec791d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The margin component is added in the loss of the reward model in Llama 2 by using a preference rating-based margin term. This margin term is used in Equation 2 and helps to separate comparison pairs more effectively. The magnitude of the margin term can be adjusted to achieve better performance on separable pairs, but it may regress performance on similar samples.\n"
     ]
    }
   ],
   "source": [
    "base_query_engine = vector_index.as_query_engine(similarity_top_k=1)\n",
    "base_response = base_query_engine.query(\n",
    "    \"How is the margin component added in the loss of the reward model in\"\n",
    "    \" Llama 2?\"\n",
    ")\n",
    "print(str(base_response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdde395f-7a26-48f1-8c7d-038858866313",
   "metadata": {},
   "source": [
    "## Evaluate Results\n",
    "\n",
    "We run evaluations, over both the validation set but also the training set (as a sanity check)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "343a872b-53ae-4942-a7ce-aa8f80381828",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12acc83f-c006-4189-953c-5a482a53d12f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.llms import ChatMessage\n",
    "from llama_index.evaluation.eval_utils import get_responses, get_results_df\n",
    "from llama_index.evaluation import BatchEvalRunner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dd58d18-37be-4a1e-8a75-bb49ec834f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_dataset = QueryResponseDataset.from_json(\"data_rag/qa_pairs_train.json\")\n",
    "# val_dataset = QueryResponseDataset.from_json(\"data_rag/qa_pairs_val.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6801710-0a42-480b-aac9-9555d03ba8a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "# NOTE: we need to run over the original questions, not the retrieval-augmented questions.\n",
    "# Since our query engines will perform retrieval augmentation under the hood!\n",
    "\n",
    "# TODO: have better code here\n",
    "qr_pairs = load_dataset_from_other_nb(\"data/qa_pairs_2.jsonl\")\n",
    "eval_dataset = QueryResponseDataset.from_qr_pairs(qr_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ae716b1-38f8-43aa-8994-c22a155c75bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate over training dataset for now\n",
    "sample_size = 50\n",
    "\n",
    "eval_qs = eval_dataset.questions[:sample_size]\n",
    "ref_response_strs = [r for (_, r) in eval_dataset.qr_pairs[:sample_size]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "670d8be0-295f-4d64-b6b6-c14351edbb24",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_responses = get_responses(eval_qs, ft_query_engine, show_progress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c8da2ab-6938-4988-b7b1-fb4fe9232539",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_pred_responses = get_responses(\n",
    "    eval_qs, base_query_engine, show_progress=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eeff4e4-3c2a-4cbd-883d-49bdebcec297",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "pred_response_strs = [str(p) for p in pred_responses]\n",
    "base_pred_response_strs = [str(p) for p in base_pred_responses]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9569fdc-d14f-4d1a-8b32-eedb66e35bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation import (\n",
    "    CorrectnessEvaluator,\n",
    "    SemanticSimilarityEvaluator,\n",
    ")\n",
    "\n",
    "eval_service_context = ServiceContext.from_defaults(llm=OpenAI(model=\"gpt-4\"))\n",
    "# NOTE: can uncomment other evaluators\n",
    "evaluator_c = CorrectnessEvaluator(service_context=eval_service_context)\n",
    "evaluator_s = SemanticSimilarityEvaluator(service_context=eval_service_context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "804b97d9-9484-4347-91c8-bed12fc81c7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator_dict = {\n",
    "    \"correctness\": evaluator_c,\n",
    "    \"semantic_similarity\": evaluator_s,\n",
    "}\n",
    "batch_runner = BatchEvalRunner(evaluator_dict, workers=2, show_progress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c0b9e72-f59e-436c-a507-05f2a0b08f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_results = await batch_runner.aevaluate_responses(\n",
    "    eval_qs, responses=pred_responses, reference=ref_response_strs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67ce45bb-1cc9-4c3c-af52-484773e572ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_eval_results = await batch_runner.aevaluate_responses(\n",
    "    eval_qs, responses=base_pred_responses, reference=ref_response_strs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "468cfdad-75db-4525-bd1d-3d20ae509653",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>names</th>\n",
       "      <th>correctness</th>\n",
       "      <th>semantic_similarity</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>RAG Fine-tuned LLM</td>\n",
       "      <td>3.65</td>\n",
       "      <td>0.941940</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Base LLM</td>\n",
       "      <td>3.25</td>\n",
       "      <td>0.917662</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                names  correctness  semantic_similarity\n",
       "0  RAG Fine-tuned LLM         3.65             0.941940\n",
       "1            Base LLM         3.25             0.917662"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_df = get_results_df(\n",
    "    [eval_results, base_eval_results],\n",
    "    [\"RAG Fine-tuned LLM\", \"Base LLM\"],\n",
    "    [\"correctness\", \"semantic_similarity\"],\n",
    ")\n",
    "display(results_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
